import os
import re
import shutil

import torch
import torch.nn as nn

from allennlp.data import Vocabulary
from allennlp.training import Checkpointer
from tools.utils import create_dir_of_file


class Saver:
    """Save/restore model weights and vocabulary under ``<d_ckpt>/<ckpt_id>``.

    The checkpoint directory is created on construction. Weights are written
    with ``torch.save`` (``last.th``) and read back from ``best.th``,
    ``last.th`` or ``model_state_epoch_<N>[.<stamp>].th``. The vocabulary lives
    in a ``vocab`` sub-directory written by ``model.vocab.save_to_files`` and
    read by ``Vocabulary.from_files``.
    """

    # Matches ``model_state_epoch_<epoch>.th`` and
    # ``model_state_epoch_<epoch>.<stamp>.th``. The stamp may contain digits
    # and dashes (the original parser's character class allowed ``-``;
    # presumably a date-like stamp — confirm against the trainer in use).
    _EPOCH_FILE_RE = re.compile(r"model_state_epoch_(\d+)(?:\.([0-9\-]+))?\.th")

    def __init__(self, ckpt_id, d_ckpt='ckpt1'):
        """Create (if needed) the checkpoint directory ``<d_ckpt>/<ckpt_id>``.

        Args:
            ckpt_id: name of this run's sub-directory.
            d_ckpt: root directory holding all checkpoint runs.
        """
        self.d_ckpt = d_ckpt
        self.p_ckpt = os.path.join(d_ckpt, ckpt_id)
        os.makedirs(self.p_ckpt, exist_ok=True)

    def save(self, model):
        """Save ``model`` weights as ``last.th`` plus its vocabulary, then prune.

        ``last.th`` is written only when ``best.th`` is not already present.
        NOTE(review): presumably ``best.th`` is produced by an external trainer
        and should take precedence — confirm.

        Args:
            model: object exposing ``state_dict()`` and ``vocab.save_to_files``.
        """
        if not os.path.exists(os.path.join(self.p_ckpt, 'best.th')):
            torch.save(model.state_dict(), os.path.join(self.p_ckpt, 'last.th'))
        model.vocab.save_to_files(os.path.join(self.p_ckpt, 'vocab'))
        self.clear_useless_checkpoint()

    def load_best_epoch(self, model, params):
        """Build ``model`` and load ``best.th`` into it (see :meth:`load`)."""
        return self.load('best.th', model, params)

    def load_last_epoch(self, model, params):
        """Build ``model`` and load the most recent weights file into it.

        The file is chosen by :meth:`find_last_epoch_filename`.
        """
        file_name = self.find_last_epoch_filename()
        return self.load(file_name, model, params)

    def load_epoch(self, epoch, model, params):
        """Build ``model`` and load ``model_state_epoch_<epoch>.th`` into it."""
        return self.load(f'model_state_epoch_{epoch}.th', model, params)

    def load(self, file_name, model, params, map_location='cpu'):
        """Construct ``model`` with the saved vocabulary and load weights into it.

        Args:
            file_name: weights file name inside the checkpoint directory.
            model: callable (typically a model class) invoked as
                ``model(**params, vocab=<saved Vocabulary>)``.
            params: extra keyword arguments for ``model``. The caller's dict is
                not modified; any ``'vocab'`` entry is overridden in a copy.
            map_location: forwarded to ``torch.load``. The default ``'cpu'``
                lets GPU-saved checkpoints load on CPU-only hosts;
                ``load_state_dict`` then copies the tensors into the model's
                existing parameters wherever they live.

        Returns:
            The constructed model with the state dict loaded.
        """
        vocab = Vocabulary.from_files(os.path.join(self.p_ckpt, 'vocab'))
        kwargs = dict(params)  # copy: don't leak 'vocab' into the caller's dict
        kwargs['vocab'] = vocab
        m = model(**kwargs)
        state = torch.load(os.path.join(self.p_ckpt, file_name),
                           map_location=map_location)
        m.load_state_dict(state)
        return m

    def find_last_epoch_filename(self):
        """Return the file name of the most recent weights in the checkpoint dir.

        ``last.th`` wins if present. Otherwise the ``model_state_epoch_*`` file
        with the highest epoch is chosen. Among equal epochs, the file with the
        lexicographically greatest stamp wins, and an unstamped file ranks
        lowest. Names that do not fully match the expected pattern are ignored.

        Raises:
            FileNotFoundError: if neither ``last.th`` nor any parseable
                ``model_state_epoch_*.th`` file exists.
        """
        serialization_files = os.listdir(self.p_ckpt)
        if 'last.th' in serialization_files:
            return 'last.th'

        candidates = []
        for name in serialization_files:
            match = self._EPOCH_FILE_RE.fullmatch(name)
            if match is None:
                continue  # foreign/unparseable name: skip instead of crashing
            epoch, stamp = match.groups()
            candidates.append((int(epoch), stamp or '', name))

        if not candidates:
            raise FileNotFoundError(
                f'no last.th or model_state_epoch_*.th found in {self.p_ckpt}')
        # Return the real file name rather than rebuilding it from its parts.
        return max(candidates)[2]

    def clear_useless_checkpoint(self):
        """Delete training-state/log artifacts and BERT vocabulary files.

        In the checkpoint dir, every entry whose name contains
        ``'training_state_epoch'`` or ``'log'`` is removed (directories
        recursively). In the ``vocab`` sub-directory, if it exists, every entry
        whose name contains ``'bert'`` is removed.
        NOTE(review): these are substring matches, so e.g. ``'catalog'`` would
        also be deleted — confirm this is intended.
        """
        for entry in os.listdir(self.p_ckpt):
            # any() ensures each entry is deleted at most once, even when it
            # matches several patterns.
            if any(pat in entry for pat in ('training_state_epoch', 'log')):
                p_del = os.path.join(self.p_ckpt, entry)
                if os.path.isdir(p_del):
                    shutil.rmtree(p_del)
                else:
                    os.remove(p_del)

        p_vocab = os.path.join(self.p_ckpt, 'vocab')
        if not os.path.isdir(p_vocab):
            return  # nothing saved yet
        for entry in os.listdir(p_vocab):
            if 'bert' in entry:
                os.remove(os.path.join(p_vocab, entry))
